{
 "cells": [
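  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Approach: DQN + TFIDF + Cosine + tri-gram\n",
    "\n",
    "---\n",
    "\n",
    "Step 1. Build a multi-turn training set D in QQ+A form.\n",
    "\n",
    "Step 2. Train the sentence2vec model and the TFIDF-IR model on D, using tri-grams as the text features.\n",
    "\n",
    "Step 3. Train a DQN model that picks, from the candidate answers returned by the TFIDF-IR model, the answer that maximizes the BLEU score."
   ]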
  },
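  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the retrieval idea in step 2 (TFIDF over tri-grams, ranked by cosine similarity), using only the standard library. It is not the `jddc.tfidf` implementation: the function names (`char_trigrams`, `build_index`, `weight`, `cosine`, `top_k`), the choice of character-level tri-grams, and the smoothed IDF formula are all assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def char_trigrams(text):\n",
    "    # Character tri-grams (assumption: the notebook does not say char or word level).\n",
    "    return [text[i:i + 3] for i in range(len(text) - 2)]\n",
    "\n",
    "\n",
    "def build_index(questions):\n",
    "    # Tri-gram counts per stored question plus a smoothed IDF per tri-gram.\n",
    "    counts = [Counter(char_trigrams(q)) for q in questions]\n",
    "    doc_freq = Counter(g for c in counts for g in c)\n",
    "    n = len(questions)\n",
    "    idf = {g: math.log((1 + n) / (1 + d)) + 1.0 for g, d in doc_freq.items()}\n",
    "    return counts, idf\n",
    "\n",
    "\n",
    "def weight(counter, idf):\n",
    "    # TF * IDF weights; tri-grams never seen in the index are dropped.\n",
    "    return {g: tf * idf[g] for g, tf in counter.items() if g in idf}\n",
    "\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Cosine similarity between two sparse vectors stored as dicts.\n",
    "    dot = sum(w * b.get(g, 0.0) for g, w in a.items())\n",
    "    norm_a = math.sqrt(sum(w * w for w in a.values()))\n",
    "    norm_b = math.sqrt(sum(w * w for w in b.values()))\n",
    "    return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0\n",
    "\n",
    "\n",
    "def top_k(query, questions, answers, k=2):\n",
    "    # Rank stored questions against the query and return (score, answer) pairs.\n",
    "    counts, idf = build_index(questions)\n",
    "    q_vec = weight(Counter(char_trigrams(query)), idf)\n",
    "    scored = [(cosine(q_vec, weight(c, idf)), a) for c, a in zip(counts, answers)]\n",
    "    scored.sort(key=lambda pair: pair[0], reverse=True)\n",
    "    return scored[:k]\n",
    "\n",
    "\n",
    "# Toy data, invented for this example.\n",
    "questions = ['where is my order', 'how do I return an item', 'can I change my address']\n",
    "answers = ['A1', 'A2', 'A3']\n",
    "candidates = top_k('where is my order now', questions, answers)\n",
    "```\n",
    "\n",
    "The stored question that shares the most tri-grams with the query ranks first, and its paired answer heads the candidate list passed on to the selection step."
   ]
  },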
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# Make the local `jddc` package importable (machine-specific path).\n",
    "sys.path.insert(0, \"/home/team55/notespace/zengbin\")\n",
    "\n",
    "from jddc.config import DQNConfig\n",
    "from jddc.embedding import load_s2v_model\n",
    "from jddc.tfidf import load_tfidf_ir_model\n",
    "from jddc.dqn import DQNTrainer\n",
    "import jddc.utils as u"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "conf = DQNConfig()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data, sentence2vec, and TFIDF-IR\n",
    "---"
   ]
  },
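  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Multi-turn sessions, loaded from the pickle path set in the config.\n",
    "sessions = u.read_from_pkl(conf.pkl_mqa_1000)\n",
    "# Previously trained sentence2vec and TFIDF-IR models (see step 2).\n",
    "s2v_model = load_s2v_model()\n",
    "ir_model = load_tfidf_ir_model()"
   ]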
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## DQN training\n",
    "\n",
    "---"
   ]
  },
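  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the selection target in step 3 concrete, the sketch below scores each candidate answer against a reference answer and returns the best-scoring one. It is not the `jddc.dqn` code: `ngrams`, `simple_bleu`, and `best_candidate` are hypothetical names, and the scorer (uni-gram and bi-gram precision with add-one smoothing and a brevity penalty) is a simplified stand-in for a full BLEU implementation.\n",
    "\n",
    "```python\n",
    "import math\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def ngrams(tokens, n):\n",
    "    # All contiguous n-grams of a token list, as tuples.\n",
    "    return [tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1)]\n",
    "\n",
    "\n",
    "def simple_bleu(candidate, reference, max_n=2):\n",
    "    # Geometric mean of clipped n-gram precisions times a brevity penalty.\n",
    "    # Add-one smoothing keeps short sentences from scoring exactly zero.\n",
    "    if not candidate:\n",
    "        return 0.0\n",
    "    log_sum = 0.0\n",
    "    for n in range(1, max_n + 1):\n",
    "        cand = Counter(ngrams(candidate, n))\n",
    "        ref = Counter(ngrams(reference, n))\n",
    "        overlap = sum(min(count, ref[g]) for g, count in cand.items())\n",
    "        total = sum(cand.values())\n",
    "        log_sum += math.log((overlap + 1) / (total + 1))\n",
    "    if len(candidate) >= len(reference):\n",
    "        penalty = 1.0\n",
    "    else:\n",
    "        penalty = math.exp(1 - len(reference) / len(candidate))\n",
    "    return penalty * math.exp(log_sum / max_n)\n",
    "\n",
    "\n",
    "def best_candidate(candidates, reference):\n",
    "    # Index of the highest-scoring candidate, plus all scores.\n",
    "    scores = [simple_bleu(c, reference) for c in candidates]\n",
    "    best = max(range(len(scores)), key=lambda i: scores[i])\n",
    "    return best, scores\n",
    "\n",
    "\n",
    "# Toy data, invented for this example.\n",
    "reference = 'your order has shipped'.split()\n",
    "candidates = ['please wait'.split(), 'your order has shipped'.split(), 'your order'.split()]\n",
    "best, scores = best_candidate(candidates, reference)\n",
    "```\n",
    "\n",
    "A score of this kind can serve as the reward signal during training, when reference answers are known; at inference time the model has to choose without one."
   ]
  },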
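  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Use the GPU when one is available; otherwise fall back to the CPU.\n",
    "trainer = DQNTrainer(sessions, s2v_model, ir_model, use_cuda=torch.cuda.is_available())"
   ]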
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python3.6",
   "language": "python",
   "name": "python3.6"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
